import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax, degree

class GATConv_pre(MessagePassing):
    """Dot-product attention layer whose messages are modulated by a second modality.

    Each message ``x_j`` is blended with the other modality's feature ``y_j``
    (hard or soft modulation), weighted by a softmax-normalized attention
    coefficient, and summed into the target node (``aggr='add'``).

    Args:
        in_channels: Stored on the instance; not used by the computation.
        out_channels: Width of ``y``; sizes the soft-modulation gate ``fcy``.
        alpha_threshold: String whose first character selects the mode and
            whose remainder is parsed with ``float`` (so it must be numeric in
            every mode), e.g. ``'>0.5'``. ``'>'`` / ``'<'`` select hard
            modulation (blend only edges whose attention is above / below the
            threshold); any other first character selects soft modulation.
        reattn: If true, recompute attention on the modulated messages.
        central: Stored on the instance; not used by this layer.
        self_loops: If true, replace existing self-loops with exactly one
            self-loop per node before propagating.
    """

    def __init__(self, in_channels, out_channels, alpha_threshold, reattn, central, self_loops=False):
        super(GATConv_pre, self).__init__(aggr='add')
        self.self_loops = self_loops
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.alpha_threshold = float(alpha_threshold[1:])
        self.compare = alpha_threshold[0]
        self.reattn = reattn
        self.central = central
        self.fcy = torch.nn.Linear(out_channels, out_channels)

    def forward(self, x, y, edge_index, size=None):
        """Propagate modulated, attention-weighted messages.

        Args:
            x: Node features that are aggregated.
            y: Features of the modulating modality, indexed like ``x``.
            edge_index: Two-row edge list as expected by ``propagate``.
            size: Optional size forwarded unchanged to ``propagate``.

        Returns:
            The per-node sum of weighted messages.
        """
        if self.self_loops:
            # Strip pre-existing loops first so each node gets exactly one.
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=size, x=x, y=y)

    @staticmethod
    def _attention(x_i, x_j, index, num_nodes):
        """Dot-product scores softmax-normalized over edges sharing ``index``."""
        return softmax(src=torch.mul(x_i, x_j).sum(dim=-1), index=index, num_nodes=num_nodes)

    def message(self, edge_index_i, x_i, x_j, y_j, size_i):
        """Build the weighted message for every edge.

        Hard mode (equations 1-2): attention on the unmodulated features
        decides which edges average ``x_j`` with ``y_j``. Soft mode
        (equations 3-4): a sigmoid gate computed from ``y_j`` scales ``y_j``
        before averaging. With ``reattn`` the attention is recomputed on the
        modulated messages (equation 5); otherwise the attention on the
        unmodulated features is used. The coefficients used are stored in
        ``self.alpha``.
        """
        hard = self.compare in ('>', '<')
        alpha = None
        if hard or not self.reattn:
            # Equation 1. Soft mode needs this too when reattn is off;
            # otherwise no coefficient would exist for the final weighting.
            alpha = self._attention(x_i, x_j, edge_index_i, size_i)

        if hard:
            score = alpha.unsqueeze(1)  # (E, 1) broadcasts against x_j in torch.where
            if self.compare == '>':
                blend = score > self.alpha_threshold
            else:
                blend = score < self.alpha_threshold
            x_j = torch.where(blend, (x_j + y_j) / 2, x_j)  # equation 2
        else:
            gating = torch.sigmoid(self.fcy(y_j))  # equation 3
            x_j = (x_j + gating * y_j) / 2  # equation 4

        if self.reattn:
            alpha = self._attention(x_i, x_j, edge_index_i, size_i)  # equation 5

        self.alpha = alpha
        return x_j * alpha.view(-1, 1)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
    
class GATConv_ori(MessagePassing):
    """Parameter-free dot-product attention convolution.

    Each message ``x_j`` is scaled by the softmax (grouped by target index)
    of ``<x_i, x_j>`` and the scaled messages are summed (``aggr='add'``).

    Args:
        in_channels: Stored on the instance; not used by the computation.
        out_channels: Stored on the instance; not used by the computation.
        self_loops: If true, add one self-loop per node after existing
            self-loops have been removed.
    """

    def __init__(self, in_channels, out_channels, self_loops=False):
        super().__init__(aggr='add')
        self.self_loops = self_loops
        self.in_channels, self.out_channels = in_channels, out_channels

    def forward(self, x, edge_index, size=None):
        """Run one round of attention-weighted sum aggregation.

        Existing self-loops are always dropped; one per node is re-added
        only when ``self_loops`` is set.
        """
        graph, _ = remove_self_loops(edge_index)
        if self.self_loops:
            graph, _ = add_self_loops(graph, num_nodes=x.size(0))
        return self.propagate(graph, size=size, x=x)

    def message(self, edge_index_i, x_i, x_j, size_i):
        """Scale every neighbour feature by its normalized attention score.

        The coefficients are kept in ``self.alpha``.
        """
        scores = torch.mul(x_i, x_j).sum(dim=-1)
        self.alpha = softmax(src=scores, index=edge_index_i, num_nodes=size_i)
        return self.alpha.unsqueeze(-1) * x_j

    def update(self, aggr_out):
        """Pass the aggregated messages through untouched."""
        return aggr_out
        
           
class CGCN(nn.Module):
    """Per-modality graph encoder: cross-modal modulation, then residual attention rounds.

    ``preference`` and ``features`` are stacked row-wise into one node table
    (``preference`` rows first); ``preference2`` / ``features2`` are stacked
    the same way and act as the modulating modality.

    Args:
        num_user: Stored on the instance; not used here.
        num_item: Stored on the instance; not used here.
        dim_C: Feature width passed to both convolution layers.
        num_routing: Number of residual ``GATConv_ori`` rounds.
        has_act: Apply leaky ReLU to each round's output before adding it.
        has_norm: Apply ``F.normalize`` after modulation and after each round.
        alpha_threshold: Forwarded to ``GATConv_pre``.
        reattn: Forwarded to ``GATConv_pre``.
        central: Forwarded to ``GATConv_pre``.
    """

    def __init__(self, num_user, num_item, dim_C, num_routing, has_act, has_norm, alpha_threshold, reattn, central):
        super(CGCN, self).__init__()
        self.num_user = num_user
        self.num_item = num_item
        self.num_routing = num_routing
        self.has_act = has_act
        self.has_norm = has_norm
        self.dim_C = dim_C
        self.conv_embed_1 = GATConv_pre(
            self.dim_C, self.dim_C, alpha_threshold, reattn, central)
        self.conv_embed_2 = GATConv_ori(self.dim_C, self.dim_C)

    def forward(self, preference, features, preference2, features2, edge_index):
        """Encode all nodes of this modality.

        Returns:
            The refined node table, rows ordered like ``cat(preference, features)``.
        """
        x = torch.cat((preference, features), dim=0)
        y = torch.cat((preference2, features2), dim=0)

        # Append every edge reversed so messages flow in both directions.
        edge_index = torch.cat((edge_index, edge_index[[1, 0]]), dim=1)

        # Cross-modal modulation.
        x = self.conv_embed_1(x, y, edge_index)
        if self.has_norm:
            x = F.normalize(x)

        # Modal-specific residual convolution rounds.
        for _ in range(self.num_routing):
            x_hat = self.conv_embed_2(x, edge_index)
            if self.has_act:
                x_hat = F.leaky_relu(x_hat)
            # Out-of-place: never mutate a tensor already recorded by autograd.
            x = x + x_hat
            if self.has_norm:
                x = F.normalize(x)
        return x

class GCNConv(MessagePassing):
    """Weight-free GCN propagation with symmetric degree normalization.

    Each message ``x_j`` on edge ``(row, col)`` is scaled by
    ``deg[row]^-1/2 * deg[col]^-1/2``, where ``deg`` counts how often each
    node appears in ``edge_index[0]``.

    Args:
        in_channels: Accepted for interface compatibility; the layer has no weights.
        out_channels: Accepted for interface compatibility; the layer has no weights.
        aggr: Aggregation scheme handed to ``MessagePassing`` (default ``'add'``).
    """

    def __init__(self, in_channels, out_channels, aggr='add'):
        # Honor the requested aggregation instead of silently forcing 'add'.
        super(GCNConv, self).__init__(aggr=aggr)

    def forward(self, x, edge_index):
        """Propagate ``x`` over ``edge_index`` on a square ``N x N`` graph."""
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        """Scale each neighbour feature by its symmetric normalization factor."""
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes missing from ``row`` have degree 0 -> inf. Zero them so an
        # asymmetric edge list cannot push inf/NaN into the output.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.0)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
    
class RealGCN(torch.nn.Module):
    """ID-embedding encoder: three stacked ``GCNConv`` hops with a layer-sum readout.

    Args:
        num_user: Number of user rows in the learned embedding table.
        num_item: Number of item rows in the learned embedding table.
        dim_E: Embedding width.
        aggr_mode: Aggregation passed to every ``GCNConv``.
        has_act: Apply leaky ReLU after the first and second hops.
        has_norm: ``F.normalize`` the embedding table before propagating.
    """

    def __init__(self, num_user, num_item, dim_E, aggr_mode, has_act, has_norm):
        super(RealGCN, self).__init__()
        self.num_user = num_user
        self.num_item = num_item
        self.dim_E = dim_E
        self.aggr_mode = aggr_mode
        self.has_act = has_act
        self.has_norm = has_norm
        self.id_embedding = nn.Parameter(nn.init.xavier_normal_(torch.rand((num_user+num_item, dim_E))))
        self.conv_embed_1 = GCNConv(dim_E, dim_E, aggr=aggr_mode)
        self.conv_embed_2 = GCNConv(dim_E, dim_E, aggr=aggr_mode)
        self.conv_embed_3 = GCNConv(dim_E, dim_E, aggr=aggr_mode)

    def forward(self, edge_index):
        """Return the input embeddings plus the outputs of all three hops."""
        x = self.id_embedding
        # Append every edge reversed so messages flow in both directions.
        edge_index = torch.cat((edge_index, edge_index[[1, 0]]), dim=1)
        if self.has_norm:
            x = F.normalize(x)

        x_hat_1 = self.conv_embed_1(x, edge_index)
        if self.has_act:
            x_hat_1 = F.leaky_relu_(x_hat_1)

        x_hat_2 = self.conv_embed_2(x_hat_1, edge_index)
        if self.has_act:
            x_hat_2 = F.leaky_relu_(x_hat_2)

        # The third hop must build on the second; feeding x_hat_1 again would
        # just recompute the (weight-free) second hop.
        x_hat_3 = self.conv_embed_3(x_hat_2, edge_index)
        return x + x_hat_1 + x_hat_2 + x_hat_3

class EgoGCN(nn.Module):
    """Multimodal encoder combining ID, visual and textual graph representations.

    Builds one ``CGCN`` per modality (each modulated by the other modality)
    plus a ``RealGCN`` over learned ID embeddings. The three outputs are
    concatenated feature-wise, and the result is split into user rows and item rows.

    Args:
        n_users: Number of user nodes.
        n_items: Number of item nodes.
        embedding_dim: Width of each branch's representation.
        weight_size: Accepted for interface compatibility; not used here.
        dropout_list: Accepted for interface compatibility; not used here.
        image_feats: Visual feature matrix (converted with ``torch.tensor``);
            rows presumably aligned with items. Required.
        text_feats: Text feature matrix, same conventions as ``image_feats``. Required.
        adj: Accepted for interface compatibility; not used here.
        has_norm: Forwarded to the GCN branches; also normalizes the inputs
            fed to ``CGCN`` in ``forward``.
        has_act: Forwarded to the GCN branches.
        num_routing: Forwarded to ``CGCN``.
        alpha_threshold: Forwarded to ``CGCN``.
        reattn: Forwarded to ``CGCN``.
        central: ``'central_user'`` reverses edge direction in ``forward``;
            also forwarded to ``CGCN``.
        edge_index: Edge list with one edge per row (transposed here into a
            two-row layout; presumably user->item pairs). Required.
        aggr_mode: Aggregation forwarded to ``RealGCN``.

    Raises:
        ValueError: If ``image_feats``, ``text_feats`` or ``edge_index`` is None.
    """

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats=None, text_feats=None, adj=None, has_norm=True, has_act=False, num_routing=2, alpha_threshold='>0.5', reattn=True, central='central_item',edge_index=None,aggr_mode='add'):
        super().__init__()
        # forward() uses both modalities and the graph unconditionally, so
        # fail early with a clear message instead of inside torch.tensor(None).
        if image_feats is None or text_feats is None:
            raise ValueError("EgoGCN requires both image_feats and text_feats")
        if edge_index is None:
            raise ValueError("EgoGCN requires edge_index")

        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        # Not used by forward(); kept so existing checkpoints still load.
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        self.has_norm = has_norm
        self.alpha_threshold = alpha_threshold
        self.reattn = reattn
        self.central = central
        # Non-persistent buffers follow model.to(device) / .cuda() instead of
        # being pinned to the default CUDA device, and stay out of state_dict.
        self.register_buffer(
            'edge_index', torch.tensor(edge_index).t().contiguous(), persistent=False)

        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

        self.register_buffer(
            'v_feat', torch.tensor(image_feats, dtype=torch.float), persistent=False)
        self.register_buffer(
            't_feat', torch.tensor(text_feats, dtype=torch.float), persistent=False)

        dim_E = dim_C = embedding_dim
        self.MLPv = nn.Linear(self.v_feat.size(1), dim_C)
        self.MLPt = nn.Linear(self.t_feat.size(1), dim_C)

        self.preferencev = nn.Parameter(
            nn.init.xavier_normal_(torch.rand((n_users, dim_C))))
        self.preferencet = nn.Parameter(
            nn.init.xavier_normal_(torch.rand((n_users, dim_C))))

        self.v_gcn = CGCN(n_users, n_items, dim_C, num_routing, has_act,
                          has_norm, self.alpha_threshold, self.reattn, self.central)
        self.t_gcn = CGCN(n_users, n_items, dim_C, num_routing, has_act,
                          has_norm, self.alpha_threshold, self.reattn, self.central)

        self.id_gcn_real = RealGCN(n_users, n_items, dim_E, aggr_mode, has_act, has_norm)

    def forward(self, training=1):
        """Encode all users and items.

        Args:
            training: 1 uses both modalities as given. 2 replaces the visual
                features and visual user preferences by their column means.
                3 does the same for the text modality.

        Returns:
            ``(user_rep, item_rep)``: the first ``n_users`` rows and the
            remaining rows of the concatenated ID/visual/text representation.

        Raises:
            ValueError: If ``training`` is not 1, 2 or 3.
        """
        if training not in (1, 2, 3):
            raise ValueError("training must be 1, 2 or 3, got {!r}".format(training))

        # central item
        edge_index = self.edge_index  # user->item
        if self.central == 'central_user':
            edge_index = edge_index[[1, 0]]  # item->user

        preferencev = self.preferencev
        preferencet = self.preferencet
        v_feat = self.v_feat
        t_feat = self.t_feat
        if training == 2:
            # Every item / user sees the mean visual signal.
            v_feat = v_feat.mean(dim=0).tile(self.n_items, 1)
            preferencev = preferencev.mean(dim=0).tile(self.n_users, 1)
        elif training == 3:
            # Every item / user sees the mean textual signal.
            t_feat = t_feat.mean(dim=0).tile(self.n_items, 1)
            preferencet = preferencet.mean(dim=0).tile(self.n_users, 1)
        features_v = F.leaky_relu_(self.MLPv(v_feat))
        features_t = F.leaky_relu_(self.MLPt(t_feat))

        if self.has_norm:
            preferencev = F.normalize(preferencev)
            preferencet = F.normalize(preferencet)
            features_v = F.normalize(features_v)
            features_t = F.normalize(features_t)

        # Each modality branch is modulated by the other modality.
        v_rep = self.v_gcn(preferencev, features_v,
                           preferencet, features_t, edge_index)
        t_rep = self.t_gcn(preferencet, features_t,
                           preferencev, features_v, edge_index)
        id_rep = self.id_gcn_real(edge_index)

        rep = torch.cat((id_rep, v_rep, t_rep), dim=1)
        user_rep = rep[:self.n_users]
        item_rep = rep[self.n_users:]
        return user_rep, item_rep